{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/VideoMAE/Quick_inference_with_VideoMAE.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yE_-ROlu5FKf"
      },
      "source": [
        "## Set-up environment\n",
        "\n",
        "First, let's install 🤗 Transformers and decord, which we'll use to decode a video."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "75Zxzduq15cz",
        "outputId": "6db54d33-dd79-41bb-830a-01a897bbec5f"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "    Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n"
          ]
        }
      ],
      "source": [
        "!pip install -q git+https://github.com/huggingface/transformers.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "jqFJ4a_K3Am0"
      },
      "outputs": [],
      "source": [
        "!pip install -q decord"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1xz1lSAi2zcp"
      },
      "source": [
        "## Load video\n",
        "\n",
        "Let's load a video from the [Kinetics-400](https://www.deepmind.com/open-source/kinetics) dataset. This dataset contains around 300,000 YouTube video clips, each annotated with one of 400 human action classes.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "c4ofrsAm6p54",
        "outputId": "88c7b0c1-c43a-4b52-e0d5-c6d1bed6eccc"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--2022-08-04 16:06:14--  https://huggingface.co/datasets/nielsr/video-demo/resolve/main/eating_spaghetti.mp4\n",
            "Resolving huggingface.co (huggingface.co)... 34.231.117.252, 52.2.34.29, 2600:1f18:147f:e850:d57d:d46a:df34:61ee, ...\n",
            "Connecting to huggingface.co (huggingface.co)|34.231.117.252|:443... connected.\n",
            "HTTP request sent, awaiting response... 302 Found\n",
            "Location: https://cdn-lfs.huggingface.co/repos/21/27/2127ba3909eec39f0c04aa658b6aa97c12af51427ff415d000d565c97e36724b/252f63d13748f08acf56765c295506bfdb8bb73b822e93a33a57d73988814a71?response-content-disposition=attachment%3B%20filename%3D%22eating_spaghetti.mp4%22 [following]\n",
            "--2022-08-04 16:06:14--  https://cdn-lfs.huggingface.co/repos/21/27/2127ba3909eec39f0c04aa658b6aa97c12af51427ff415d000d565c97e36724b/252f63d13748f08acf56765c295506bfdb8bb73b822e93a33a57d73988814a71?response-content-disposition=attachment%3B%20filename%3D%22eating_spaghetti.mp4%22\n",
            "Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 13.224.167.90, 13.224.167.116, 13.224.167.3, ...\n",
            "Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|13.224.167.90|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 1013655 (990K) [video/mp4]\n",
            "Saving to: ‘eating_spaghetti.mp4.1’\n",
            "\n",
            "eating_spaghetti.mp 100%[===================>] 989.90K  1.05MB/s    in 0.9s    \n",
            "\n",
            "2022-08-04 16:06:16 (1.05 MB/s) - ‘eating_spaghetti.mp4.1’ saved [1013655/1013655]\n",
            "\n"
          ]
        }
      ],
      "source": [
        "!wget https://huggingface.co/datasets/nielsr/video-demo/resolve/main/eating_spaghetti.mp4"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 306,
          "referenced_widgets": [
            "0fccb7f9c27b428eb349bccb343e3466",
            "37a2e3f2b8f54f6e93a4a3f68477ba9b"
          ]
        },
        "id": "DFBLR2vGVK9Q",
        "outputId": "fffbdd5e-4a18-4612-f938-9629792eac59"
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Video(value=b'\\x00\\x00\\x00 ftypisom\\x00\\x00\\x02\\x00isomiso2avc1mp41\\x00\\x00\\x00\\x08free\\x00\\x0fI\\xb7mdat\\x00\\x…"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "0fccb7f9c27b428eb349bccb343e3466"
            }
          },
          "metadata": {}
        }
      ],
      "source": [
        "from ipywidgets import Video\n",
        "\n",
        "video_path = \"eating_spaghetti.mp4\"\n",
        "Video.from_file(video_path, width=500)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6LBAV-7u3cI6"
      },
      "source": [
        "## Prepare video for model\n",
        "\n",
        "We can prepare the video for the model using `VideoMAEFeatureExtractor`. We'll first sample 16 of the video's 300 frames and pass them to the feature extractor.\n",
        "\n",
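        "The `sample_frame_indices` helper in the next code cells draws a random window of `clip_len * frame_sample_rate = 16 * 4 = 64` consecutive frames and spreads 16 indices evenly over it. To make that arithmetic visible, here is a sketch built on a hypothetical `window_indices` function (not part of decord or 🤗 Transformers). It mirrors the helper but takes the window's end index as an argument instead of drawing it at random, so its output is predictable:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def window_indices(end_idx, clip_len=16, frame_sample_rate=4):\n",
        "    # hypothetical deterministic variant of sample_frame_indices:\n",
        "    # the caller fixes end_idx instead of drawing it at random\n",
        "    converted_len = clip_len * frame_sample_rate  # window length, 16 * 4 = 64 frames\n",
        "    str_idx = end_idx - converted_len\n",
        "    index = np.linspace(str_idx, end_idx, num=clip_len)\n",
        "    # the last point equals end_idx, which lies outside the window, so clip it\n",
        "    return np.clip(index, str_idx, end_idx - 1).astype(np.int64)\n",
        "\n",
        "print(window_indices(64))\n",
        "# [ 0  4  8 12 17 21 25 29 34 38 42 46 51 55 59 63]\n",
        "```\n",
        "\n",
        "Assuming the video really has 300 frames, `np.random.randint(64, 300)` can return any `end_idx` from 64 to 299. The 16 frames therefore always come from a stretch of roughly 2 seconds (64 frames at 30 FPS) and never span the whole 10-second clip.\n",
        "\n",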
        "The feature extractor then performs some basic preprocessing on each frame: resizing, center cropping and normalization."
      ]
    },
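    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The following rough NumPy sketch shows what these preprocessing steps amount to for one clip. It is **not** the feature extractor's implementation. The helpers (`resize_shortest_edge`, `center_crop`, `preprocess_clip`) are hypothetical, and the resize uses nearest-neighbour sampling, which may not match the interpolation the library applies. The ImageNet mean/standard deviation are also an assumption, so compare them with `feature_extractor.image_mean` and `feature_extractor.image_std`. The sketch shows how 16 frames of shape `(360, 640, 3)` become a single `(1, 16, 3, 224, 224)` array:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Assumption: ImageNet mean/std; compare with feature_extractor.image_mean / image_std\n",
        "MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)\n",
        "STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)\n",
        "\n",
        "def resize_shortest_edge(frame, size=224):\n",
        "    # nearest-neighbour resize so that the shorter side becomes size\n",
        "    h, w = frame.shape[:2]\n",
        "    scale = size / min(h, w)\n",
        "    new_h, new_w = round(h * scale), round(w * scale)\n",
        "    rows = (np.arange(new_h) * h / new_h).astype(np.int64)\n",
        "    cols = (np.arange(new_w) * w / new_w).astype(np.int64)\n",
        "    return frame[rows][:, cols]\n",
        "\n",
        "def center_crop(frame, size=224):\n",
        "    # keep the central size x size region\n",
        "    h, w = frame.shape[:2]\n",
        "    top, left = (h - size) // 2, (w - size) // 2\n",
        "    return frame[top:top + size, left:left + size]\n",
        "\n",
        "def preprocess_clip(frames, size=224):\n",
        "    # frames: list of (H, W, 3) uint8 arrays -> (1, T, 3, size, size) float32 array\n",
        "    out = []\n",
        "    for frame in frames:\n",
        "        frame = center_crop(resize_shortest_edge(frame, size), size)\n",
        "        frame = frame.astype(np.float32) / 255.0  # rescale to [0, 1]\n",
        "        frame = (frame - MEAN) / STD              # normalize each channel\n",
        "        out.append(frame.transpose(2, 0, 1))      # (H, W, C) -> (C, H, W)\n",
        "    return np.stack(out)[np.newaxis]              # add a batch dimension\n",
        "\n",
        "frames = [np.random.randint(0, 256, (360, 640, 3), dtype=np.uint8) for _ in range(16)]\n",
        "print(preprocess_clip(frames).shape)  # (1, 16, 3, 224, 224)\n",
        "```\n",
        "\n",
        "A 360x640 frame is first resized to 224x398, then cropped to its central 224x224 region. The pixel values produced by the real feature extractor will differ from this sketch because of the resize method and the normalization constants the checkpoint actually uses."
      ]
    },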
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "eecFgurv3du0"
      },
      "outputs": [],
      "source": [
        "from transformers import VideoMAEFeatureExtractor\n",
        "\n",
        "feature_extractor = VideoMAEFeatureExtractor.from_pretrained(\"MCG-NJU/videomae-base-finetuned-kinetics\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FqtN5Zv_3f4f",
        "outputId": "40fc6f93-ec8b-451f-d074-39744bdaf81e"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(16, 360, 640, 3)"
            ]
          },
          "metadata": {},
          "execution_count": 6
        }
      ],
      "source": [
        "from decord import VideoReader, cpu\n",
        "import numpy as np\n",
        "\n",
        "# video clip consists of 300 frames (10 seconds at 30 FPS)\n",
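        "vr = VideoReader(video_path, num_threads=1, ctx=cpu(0))\n",
        "\n",
        "def sample_frame_indices(clip_len, frame_sample_rate, seg_len):\n",
        "  # number of consecutive frames the clip spans, e.g. 16 * 4 = 64\n",
        "  converted_len = int(clip_len * frame_sample_rate)\n",
        "  # pick a random end point for the window, then step back to its start\n",
        "  end_idx = np.random.randint(converted_len, seg_len)\n",
        "  str_idx = end_idx - converted_len\n",
        "  # spread clip_len indices evenly over the window and cast them to integers\n",
        "  index = np.linspace(str_idx, end_idx, num=clip_len)\n",
        "  index = np.clip(index, str_idx, end_idx - 1).astype(np.int64)\n",
        "  \n",
        "  return index\n",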
        "\n",
        "vr.seek(0)\n",
        "index = sample_frame_indices(clip_len=16, frame_sample_rate=4, seg_len=len(vr))\n",
        "buffer = vr.get_batch(index).asnumpy()\n",
        "buffer.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GDIn_eop4P4v",
        "outputId": "a5f057ec-d7d3-4f4b-d7af-2ce57d0aa494"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "torch.Size([1, 16, 3, 224, 224])\n"
          ]
        }
      ],
      "source": [
        "# split the (num_frames, height, width, 3) array into a list of per-frame NumPy arrays\n",
        "video = [buffer[i] for i in range(buffer.shape[0])]\n",
        "\n",
        "encoding = feature_extractor(video, return_tensors=\"pt\")\n",
        "print(encoding.pixel_values.shape)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eVNAP2Ra2MTs"
      },
      "source": [
        "## Load model\n",
        "\n",
        "Next, let's load the model and move it to the GPU, if it's available."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1srziUwq2Dwu",
        "outputId": "54d84afb-9178-427e-a3d1-07abd11c59e8"
      },
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "VideoMAEForVideoClassification(\n",
              "  (videomae): VideoMAEModel(\n",
              "    (embeddings): VideoMAEEmbeddings(\n",
              "      (patch_embeddings): VideoMAEPatchEmbeddings(\n",
              "        (projection): Conv3d(3, 1024, kernel_size=(2, 16, 16), stride=(2, 16, 16))\n",
              "      )\n",
              "    )\n",
              "    (encoder): VideoMAEEncoder(\n",
              "      (layer): ModuleList(\n",
              "        (0): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (1): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (2): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (3): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (4): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (5): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (6): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (7): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (8): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (9): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (10): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (11): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (12): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (13): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (14): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (15): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (16): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (17): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (18): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (19): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (20): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (21): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (22): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "        (23): VideoMAELayer(\n",
              "          (attention): VideoMAEAttention(\n",
              "            (attention): VideoMAESelfAttention(\n",
              "              (query): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (key): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (value): Linear(in_features=1024, out_features=1024, bias=False)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "            (output): VideoMAESelfOutput(\n",
              "              (dense): Linear(in_features=1024, out_features=1024, bias=True)\n",
              "              (dropout): Dropout(p=0.0, inplace=False)\n",
              "            )\n",
              "          )\n",
              "          (intermediate): VideoMAEIntermediate(\n",
              "            (dense): Linear(in_features=1024, out_features=4096, bias=True)\n",
              "            (intermediate_act_fn): GELUActivation()\n",
              "          )\n",
              "          (output): VideoMAEOutput(\n",
              "            (dense): Linear(in_features=4096, out_features=1024, bias=True)\n",
              "            (dropout): Dropout(p=0.0, inplace=False)\n",
              "          )\n",
              "          (layernorm_before): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "          (layernorm_after): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)\n",
              "        )\n",
              "      )\n",
              "    )\n",
              "  )\n",
              "  (fc_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n",
              "  (classifier): Linear(in_features=1024, out_features=400, bias=True)\n",
              ")"
            ]
          },
          "metadata": {},
          "execution_count": 8
        }
      ],
      "source": [
        "from transformers import VideoMAEForVideoClassification\n",
        "import torch\n",
        "\n",
        "# run on the GPU if one is available\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "# load VideoMAE-large fine-tuned on Kinetics-400 (400 action classes)\n",
        "model = VideoMAEForVideoClassification.from_pretrained(\"MCG-NJU/videomae-large-finetuned-kinetics\")\n",
        "model.to(device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "u6cFXNQMWtIl"
      },
      "source": [
        "## Forward pass"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "u1grikRn2Lye"
      },
      "outputs": [],
      "source": [
        "pixel_values = encoding.pixel_values.to(device)\n",
        "\n",
        "# forward pass; gradients are not needed at inference time\n",
        "with torch.no_grad():\n",
        "    outputs = model(pixel_values)\n",
        "    logits = outputs.logits"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VcNjRjTd4eQA",
        "outputId": "82dc1561-20f7-4ebb-c494-3fa42f846f0c"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Predicted class: eating spaghetti\n"
          ]
        }
      ],
      "source": [
        "# logits has shape (batch_size, num_labels); pick the highest-scoring label\n",
        "predicted_class_idx = logits.argmax(-1).item()\n",
        "\n",
        "print(\"Predicted class:\", model.config.id2label[predicted_class_idx])"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "Quick inference with VideoMAE.ipynb",
      "provenance": [],
      "mount_file_id": "1ZX_XnM0ol81FbcxrFS3nNLkmn-0fzvQk",
      "authorship_tag": "ABX9TyPNOagUmRaW/u3pa28X/KSJ",
      "include_colab_link": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    },
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "0fccb7f9c27b428eb349bccb343e3466": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "VideoModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "VideoModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "VideoView",
            "autoplay": true,
            "controls": true,
            "format": "mp4",
            "height": "",
            "layout": "IPY_MODEL_37a2e3f2b8f54f6e93a4a3f68477ba9b",
            "loop": true,
            "width": "500"
          }
        },
        "37a2e3f2b8f54f6e93a4a3f68477ba9b": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        }
      }
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}